{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# How to Fine-Tune Gemma with ChatML and Hugging Face TRL\n",
    "\n",
    "Last week, Google released [Gemma](https://huggingface.co/blog/gemma), a new family of state-of-the-art open LLMs. Gemma comes in two sizes: a 7B parameter version for efficient deployment and development on consumer-size GPUs and TPUs, and a 2B version for CPU and on-device applications. Both come in base and instruction-tuned variants.\n",
    "\n",
    "After the first week, it seemed that Gemma is not very friendly to fine-tune using the [ChatML format](https://github.com/MicrosoftDocs/azure-docs/blob/main/articles/ai-services/openai/includes/chat-markup-language.md#working-with-chat-markup-language-chatml), which has been adopted and is widely used by the open source community, e.g. by OpenHermes or Dolphin. I created this blog post to show you how to fine-tune Gemma using ChatML and Hugging Face TRL.\n",
    "\n",
    "This blog post is derived from my [How to Fine-Tune LLMs in 2024 with Hugging Face](https://www.philschmid.de/fine-tune-llms-in-2024-with-trl) blog post, tailored to fine-tuning Gemma 7B. We will use Hugging Face [TRL](https://huggingface.co/docs/trl/index), [Transformers](https://huggingface.co/docs/transformers/index) & [datasets](https://huggingface.co/docs/datasets/index).\n",
    "\n",
    "1. Setup development environment\n",
    "2. Create and prepare the dataset\n",
    "3. Fine-tune LLM using `trl` and the `SFTTrainer` \n",
    "4. Test and evaluate the LLM\n",
    "\n",
    "_Note: This blog post was created to run on consumer-size GPUs (24GB), e.g. NVIDIA A10G or RTX 4090/3090, but can be easily adapted to run on bigger GPUs._"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Setup development environment\n",
    "\n",
    "Our first step is to install the Hugging Face libraries and PyTorch, including trl, transformers and datasets. If you haven't heard of trl yet, don't worry. It is a new library on top of transformers and datasets, which makes it easier to fine-tune, apply RLHF to, and align open LLMs.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install PyTorch & other libraries\n",
    "!pip install \"torch==2.1.2\" tensorboard\n",
    "\n",
    "# Install Hugging Face libraries\n",
    "!pip install  --upgrade \\\n",
    "  \"transformers==4.38.2\" \\\n",
    "  \"datasets==2.16.1\" \\\n",
    "  \"accelerate==0.26.1\" \\\n",
    "  \"evaluate==0.4.1\" \\\n",
    "  \"bitsandbytes==0.42.0\" \\\n",
    "  \"trl==0.7.11\" \\\n",
    "  \"peft==0.8.2\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you are using a GPU with Ampere architecture (e.g. NVIDIA A10G or RTX 4090/3090) or newer you can use Flash Attention. Flash Attention is a method that reorders the attention computation and leverages classical techniques (tiling, recomputation) to significantly speed it up and reduce memory usage from quadratic to linear in sequence length. TL;DR: it accelerates training up to 3x. Learn more at [FlashAttention](https://github.com/Dao-AILab/flash-attention/tree/main).\n",
    "\n",
    "_Note: If your machine has less than 96GB of RAM and lots of CPU cores, reduce the number of `MAX_JOBS`. On the `g5.2xlarge` we used `4`._"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch; assert torch.cuda.get_device_capability()[0] >= 8, 'Hardware not supported for Flash Attention'\n",
    "# install flash-attn\n",
    "!pip install ninja packaging\n",
    "!MAX_JOBS=4 pip install flash-attn --no-build-isolation --upgrade"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "_Installing flash attention from source can take quite a bit of time (10-45 minutes)._"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will also need to log in to our [Hugging Face](https://huggingface.co) account to be able to access Gemma. To use Gemma you first need to agree to the terms of use. You can do this by visiting the [Gemma page](https://huggingface.co/google/gemma-7b) and accepting the terms in the access gate."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import login\n",
    "\n",
    "login(\n",
    "  token=\"\", # ADD YOUR TOKEN HERE\n",
    "  add_to_git_credential=True\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Create and prepare the dataset\n",
    "\n",
    "We are not going to focus on creating a dataset in this blog post. If you want to learn more about creating a dataset, I recommend reading the [How to Fine-Tune LLMs in 2024 with Hugging Face](https://www.philschmid.de/fine-tune-llms-in-2024-with-trl) blog post. We are going to use the Databricks Dolly dataset, already formatted as messages. This means we can use the `conversational` format to fine-tune our model.\n",
    "\n",
    "```json\n",
    "{\"messages\": [{\"role\": \"system\", \"content\": \"You are...\"}, {\"role\": \"user\", \"content\": \"...\"}, {\"role\": \"assistant\", \"content\": \"...\"}]}\n",
    "{\"messages\": [{\"role\": \"system\", \"content\": \"You are...\"}, {\"role\": \"user\", \"content\": \"...\"}, {\"role\": \"assistant\", \"content\": \"...\"}]}\n",
    "{\"messages\": [{\"role\": \"system\", \"content\": \"You are...\"}, {\"role\": \"user\", \"content\": \"...\"}, {\"role\": \"assistant\", \"content\": \"...\"}]}\n",
    "```\n",
    "\n",
    "The latest release of `trl` supports the conversational dataset format. This means we don't need to do any additional formatting of the dataset. We can use the dataset as is."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "# Load Dolly Dataset.\n",
    "dataset = load_dataset(\"philschmid/dolly-15k-oai-style\", split=\"train\")\n",
    "\n",
    "print(dataset[3][\"messages\"])"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Fine-tune LLM using `trl` and the `SFTTrainer` \n",
    "\n",
    "We will use the [SFTTrainer](https://huggingface.co/docs/trl/sft_trainer) from `trl` to fine-tune our model. The `SFTTrainer` makes it straightforward to run supervised fine-tuning on open LLMs. The `SFTTrainer` is a subclass of the `Trainer` from the `transformers` library and supports all the same features, including logging, evaluation, and checkpointing, but adds additional quality-of-life features, including:\n",
    "* Dataset formatting, including conversational and instruction format\n",
    "* Training on completions only, ignoring prompts\n",
    "* Packing datasets for more efficient training\n",
    "* PEFT (parameter-efficient fine-tuning) support including Q-LoRA\n",
    "* Preparing the model and tokenizer for conversational fine-tuning (e.g. adding special tokens)\n",
    " \n",
    "We will use the dataset formatting, packing and PEFT features in our example. As our PEFT method we will use [QLoRA](https://arxiv.org/abs/2305.14314), a technique that uses quantization to reduce the memory footprint of large language models during fine-tuning without sacrificing performance.\n",
    "\n",
    "_Note: Gemma comes with a big vocabulary of ~250,000 tokens. Normally, if you want to fine-tune LLMs on the ChatML format, you would need to add special tokens to the tokenizer and model and teach the model to understand the different roles in a conversation. But Google included ~100 placeholder tokens in the vocabulary, which we can replace with special tokens like `<|im_start|>` and `<|im_end|>`. I created a tokenizer for the ChatML format, [philschmid/gemma-tokenizer-chatml](https://huggingface.co/philschmid/gemma-tokenizer-chatml), which you can use to fine-tune Gemma with ChatML._\n",
    "\n",
    "The chat template used during fine-tuning is not 100% compatible with the ChatML format, since [google/gemma-7b](https://huggingface.co/google/gemma-7b) requires inputs to always start with a `<bos>` token. This means our inputs will look like this:\n",
    "\n",
    "```xml\n",
    "<bos><|im_start|>system\n",
    "You are Gemma.<|im_end|>\n",
    "<|im_start|>user\n",
    "Hello, how are you?<|im_end|>\n",
    "<|im_start|>assistant\n",
    "I'm doing great. How can I help you today?<|im_end|>\\n<eos>\n",
    "```\n",
    "\n",
    "_Note: We don't know why Gemma needs a `<bos>` token at the beginning of the input._\n"
   ]
  },
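  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the template above concrete, here is a minimal sketch in plain Python that renders a list of messages into that layout. The helper name `render_chatml` is hypothetical and only for illustration. During training, the chat template stored in the tokenizer does this job, and its exact whitespace handling may differ from this sketch.\n",
    "\n",
    "```python\n",
    "def render_chatml(messages, bos='<bos>', eos='<eos>'):\n",
    "    \"\"\"Render messages in the ChatML-style layout shown above (illustrative only).\"\"\"\n",
    "    text = bos  # Gemma expects every input to start with <bos>\n",
    "    for message in messages:\n",
    "        # each turn: <|im_start|>role, newline, content, <|im_end|>, newline\n",
    "        text += f\"<|im_start|>{message['role']}\\n{message['content']}<|im_end|>\\n\"\n",
    "    return text + eos\n",
    "\n",
    "\n",
    "example = [\n",
    "    {'role': 'user', 'content': 'Hello, how are you?'},\n",
    "    {'role': 'assistant', 'content': 'I am doing great.'},\n",
    "]\n",
    "print(render_chatml(example))\n",
    "```\n",
    "\n",
    "Printing the result next to the output of `tokenizer.apply_chat_template(example, tokenize=False)` is a quick way to check that the tokenizer template produces the layout you expect."
   ]
  },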
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig\n",
    "\n",
    "# Hugging Face model id\n",
    "model_id = \"google/gemma-7b\"\n",
    "tokenizer_id = \"philschmid/gemma-tokenizer-chatml\"\n",
    "\n",
    "# BitsAndBytesConfig int-4 config\n",
    "bnb_config = BitsAndBytesConfig(\n",
    "    load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type=\"nf4\", bnb_4bit_compute_dtype=torch.bfloat16\n",
    ")\n",
    "\n",
    "# Load model and tokenizer\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_id,\n",
    "    device_map=\"auto\",\n",
    "    attn_implementation=\"flash_attention_2\",\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    quantization_config=bnb_config\n",
    ")\n",
    "tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)\n",
    "tokenizer.padding_side = 'right' # to prevent warnings"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `SFTTrainer` supports a native integration with `peft`, which makes it easy to efficiently tune LLMs using, e.g. QLoRA. We only need to create our `LoraConfig` and provide it to the trainer. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from peft import LoraConfig\n",
    "\n",
    "# LoRA config based on QLoRA paper & Sebastian Raschka experiment\n",
    "peft_config = LoraConfig(\n",
    "        lora_alpha=8,\n",
    "        lora_dropout=0.05,\n",
    "        r=6,\n",
    "        bias=\"none\",\n",
    "        target_modules=\"all-linear\",\n",
    "        task_type=\"CAUSAL_LM\", \n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before we can start our training we need to define the hyperparameters (`TrainingArguments`) we want to use."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments\n",
    "\n",
    "args = TrainingArguments(\n",
    "    output_dir=\"gemma-7b-dolly-chatml\", # directory to save and repository id\n",
    "    num_train_epochs=3,                     # number of training epochs\n",
    "    per_device_train_batch_size=2,          # batch size per device during training\n",
    "    gradient_accumulation_steps=2,          # number of steps to accumulate gradients before performing an optimizer update\n",
    "    gradient_checkpointing=True,            # use gradient checkpointing to save memory\n",
    "    optim=\"adamw_torch_fused\",              # use fused adamw optimizer\n",
    "    logging_steps=10,                       # log every 10 steps\n",
    "    save_strategy=\"epoch\",                  # save checkpoint every epoch\n",
    "    bf16=True,                              # use bfloat16 precision\n",
    "    tf32=True,                              # use tf32 precision\n",
    "    learning_rate=2e-4,                     # learning rate, based on QLoRA paper\n",
    "    max_grad_norm=0.3,                      # max gradient norm based on QLoRA paper\n",
    "    warmup_ratio=0.03,                      # warmup ratio based on QLoRA paper\n",
    "    lr_scheduler_type=\"constant\",           # use constant learning rate scheduler\n",
    "    push_to_hub=False,                      # do not push the model to the hub\n",
    "    report_to=\"tensorboard\",                # report metrics to tensorboard\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now have every building block we need to create our `SFTTrainer` and start training our model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from trl import SFTTrainer\n",
    "\n",
    "max_seq_length = 1512 # max sequence length for model and packing of the dataset\n",
    "\n",
    "trainer = SFTTrainer(\n",
    "    model=model,\n",
    "    args=args,\n",
    "    train_dataset=dataset,\n",
    "    peft_config=peft_config,\n",
    "    max_seq_length=max_seq_length,\n",
    "    tokenizer=tokenizer,\n",
    "    packing=True,\n",
    "    dataset_kwargs={\n",
    "        \"add_special_tokens\": False, # We template with special tokens\n",
    "        \"append_concat_token\": False, # No need to add additional separator token\n",
    "    }\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Start training our model by calling the `train()` method on our `Trainer` instance. This will start the training loop and train our model for 3 epochs. Since we are using a PEFT method, we will only save the adapted model weights and not the full model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# start training, checkpoints are saved to the output directory (nothing is pushed to the hub since push_to_hub=False)\n",
    "trainer.train()\n",
    "\n",
    "# save model \n",
    "trainer.save_model()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The training with Flash Attention for 3 epochs with a dataset of 15k samples took 4:14:36 on a `g5.2xlarge`. The instance costs `1.21$/h`, which brings us to a total cost of only ~`5.1$`.\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Optional: Merge LoRA adapter into the original model\n",
    "\n",
    "When using QLoRA, we only train adapters and not the full model. This means when saving the model during training we only save the adapter weights and not the full model. If you want to save the full model, which makes it easier to use with Text Generation Inference, you can merge the adapter weights into the model weights using the `merge_and_unload` method and then save the model with the `save_pretrained` method.\n",
    "\n",
    "Check out the [How to Fine-Tune LLMs in 2024 with Hugging Face](https://www.philschmid.de/fine-tune-llms-in-2024-with-trl#optional-merge-lora-adapter-in-to-the-original-model) blog post on how to do it.\n",
    "\n"
   ]
  },
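  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sketch of the merge step described above, assuming the adapter was saved to the training output directory and that the machine has enough memory to hold the full model in 16-bit precision: the function name `merge_adapter` and the output directory name are hypothetical, and the linked blog post remains the reference for the complete procedure.\n",
    "\n",
    "```python\n",
    "def merge_adapter(adapter_dir, output_dir):\n",
    "    \"\"\"Merge LoRA adapter weights into the base model and save the full model.\"\"\"\n",
    "    # imports are local so that defining the function needs no GPU libraries\n",
    "    import torch\n",
    "    from peft import AutoPeftModelForCausalLM\n",
    "\n",
    "    # load base model plus adapter in 16-bit, without 4-bit quantization\n",
    "    model = AutoPeftModelForCausalLM.from_pretrained(\n",
    "        adapter_dir,\n",
    "        torch_dtype=torch.float16,\n",
    "        low_cpu_mem_usage=True,\n",
    "    )\n",
    "    # fold the adapter weights into the base weights and drop the PEFT wrappers\n",
    "    merged_model = model.merge_and_unload()\n",
    "    merged_model.save_pretrained(output_dir, safe_serialization=True)\n",
    "\n",
    "\n",
    "# Example call (directory names are assumptions):\n",
    "# merge_adapter('gemma-7b-dolly-chatml', 'gemma-7b-dolly-chatml-merged')\n",
    "```\n",
    "\n",
    "The tokenizer is not part of the merged weights, so you would save it to the same directory separately if you want a self-contained model folder."
   ]
  },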
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Test Model and run Inference\n",
    "\n",
    "After the training is done we want to evaluate and test our model. We will run a few sample prompts through the model with a simple loop and inspect the generated responses.\n",
    "\n",
    "_Note: Evaluating Generative AI models is not a trivial task since 1 input can have multiple correct outputs. If you want to learn more about evaluating generative models, check out [Evaluate LLMs and RAG a practical example using Langchain and Hugging Face](https://www.philschmid.de/evaluate-llm) blog post._\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# free the memory again\n",
    "del model\n",
    "del trainer\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We load the adapted model and the tokenizer into the `pipeline` to easily test it, and extract the token id of `<|im_end|>` to use it in the `generate` method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from peft import AutoPeftModelForCausalLM\n",
    "from transformers import  AutoTokenizer, pipeline\n",
    "\n",
    "peft_model_id = \"gemma-7b-dolly-chatml\"\n",
    "\n",
    "# Load Model with PEFT adapter\n",
    "tokenizer = AutoTokenizer.from_pretrained(peft_model_id)\n",
    "model = AutoPeftModelForCausalLM.from_pretrained(peft_model_id, device_map=\"auto\", torch_dtype=torch.float16)\n",
    "pipe = pipeline(\"text-generation\", model=model, tokenizer=tokenizer)\n",
    "# get token id for end of conversation\n",
    "eos_token = tokenizer(\"<|im_end|>\",add_special_tokens=False)[\"input_ids\"][0]"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's test some prompt samples and see how the model performs. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    prompt:\n",
      "What is the capital of Germany? Explain why thats the case and if it was different in the past?\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/envs/pytorch/lib/python3.10/site-packages/transformers/pipelines/base.py:1157: UserWarning: You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    response:\n",
      "The capital of Germany is Berlin. Berlin is the capital of Germany because it is the largest city in Germany, and it is the political center of the country. Berlin is also home to the German parliament, the Bundestag, and the federal government. In the past, the capital of Germany was Berlin, but it was divided into two parts during the Cold War. The western part of Berlin was controlled by the United States, France, and the United Kingdom, while the eastern part was controlled by the Soviet Union. Berlin was also the site of the Berlin Wall, which separated the two sides of the city from 1961 to 1989. The fall of the Berlin Wall marked the end of the Cold War, and Germany was reunified in 1990. Berlin has since become the capital of a reunified Germany.\n",
      "--------------------------------------------------\n",
      "    prompt:\n",
      "Write a Python function to calculate the factorial of a number.\n",
      "    response:\n",
      "def factorial(n):\n",
      "    fact = 1\n",
      "    for i in range(1, n+1):\n",
      "        fact = fact * i\n",
      "    return fact\n",
      "--------------------------------------------------\n",
      "    prompt:\n",
      "A rectangular garden has a length of 25 feet and a width of 15 feet. If you want to build a fence around the entire garden, how many feet of fencing will you need?\n",
      "    response:\n",
      "The total amount of fencing needed is 70 feet. This is determined by multiplying the perimeter of the garden (which is 70 feet) by the cost per foot of the fencing material.\n",
      "--------------------------------------------------\n",
      "    prompt:\n",
      "What is the difference between a fruit and a vegetable? Give examples of each.\n",
      "    response:\n",
      "A fruit is the part of a plant that contains the seeds. Examples of fruits are apples, oranges, strawberries, and grapes. A vegetable is the part of a plant that is used for food but does not contain seeds. Examples of vegetables are carrots, potatoes, broccoli, and lettuce.\n",
      "--------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "prompts = [\n",
    "    \"What is the capital of Germany? Explain why thats the case and if it was different in the past?\",\n",
    "    \"Write a Python function to calculate the factorial of a number.\",\n",
    "    \"A rectangular garden has a length of 25 feet and a width of 15 feet. If you want to build a fence around the entire garden, how many feet of fencing will you need?\",\n",
    "    \"What is the difference between a fruit and a vegetable? Give examples of each.\",\n",
    "]\n",
    "\n",
    "def test_inference(prompt):\n",
    "    prompt = pipe.tokenizer.apply_chat_template([{\"role\": \"user\", \"content\": prompt}], tokenize=False, add_generation_prompt=True)\n",
    "    outputs = pipe(prompt, max_new_tokens=1024, do_sample=True, temperature=0.7, top_k=50, top_p=0.95, eos_token_id=eos_token)\n",
    "    return outputs[0]['generated_text'][len(prompt):].strip()\n",
    "\n",
    "\n",
    "for prompt in prompts:\n",
    "    print(f\"    prompt:\\n{prompt}\")\n",
    "    print(f\"    response:\\n{test_inference(prompt)}\")\n",
    "    print(\"-\"*50)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "2d58e898dde0263bc564c6968b04150abacfd33eed9b19aaa8e45c040360e146"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
